{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 构建界面（Streamlit 入口脚本，需通过 `streamlit run` 启动）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import streamlit as st\n",
    "\n",
    "from fileingestor import FileIngestor\n",
    "\n",
    "# Page heading rendered at the top of the app.\n",
    "st.title(\"Chat with PDF\")\n",
    "\n",
    "# Sidebar upload widget restricted to PDF files.\n",
    "uploaded_file = st.sidebar.file_uploader(label=\"Upload File\", type=\"pdf\")\n",
    "\n",
    "# Once a file has been provided, ingest it and render the chat UI.\n",
    "if uploaded_file is not None:\n",
    "    FileIngestor(uploaded_file).handlefileandingest()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 加载LLM模型（保存为 `loadllm.py`，供 `from loadllm import Loadllm` 导入）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms import LlamaCpp\n",
    "from langchain.callbacks.manager import CallbackManager\n",
    "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n",
    "\n",
    "# Chinese-language model: Yi (path to the yi-chat-6b f16 GGUF weights file)\n",
    "model_path = 'yi-chat-6B-GGUF/yi-chat-6b.f16.gguf'\n",
    "\n",
    "\n",
    "class Loadllm:\n",
    "    \"\"\"Factory for the llama.cpp-backed language model used by the app.\"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def load_llm():\n",
    "        \"\"\"Create a LlamaCpp LLM from the module-level ``model_path``.\n",
    "\n",
    "        A StreamingStdOutCallbackHandler is attached through a\n",
    "        CallbackManager, and verbose mode is enabled.\n",
    "\n",
    "        Returns:\n",
    "            The configured LlamaCpp instance.\n",
    "        \"\"\"\n",
    "        stdout_streamer = CallbackManager([StreamingStdOutCallbackHandler()])\n",
    "\n",
    "        llama_options = {\n",
    "            'model_path': model_path,\n",
    "            'n_gpu_layers': 40,\n",
    "            'n_batch': 512,\n",
    "            'n_ctx': 2048,\n",
    "            # MUST stay True, otherwise problems show up after a couple of calls.\n",
    "            'f16_kv': True,\n",
    "            'callback_manager': stdout_streamer,\n",
    "            'verbose': True,\n",
    "        }\n",
    "        return LlamaCpp(**llama_options)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. LangChain集成（保存为 `fileingestor.py`，供 `from fileingestor import FileIngestor` 导入）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import streamlit as st\n",
    "from langchain.document_loaders import PyMuPDFLoader\n",
    "from loadllm import Loadllm\n",
    "from streamlit_chat import message\n",
    "import tempfile\n",
    "from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
    "from langchain.vectorstores import FAISS\n",
    "from langchain.chains import ConversationalRetrievalChain\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "\n",
    "# Path the FAISS index built from the uploaded PDF is saved to (passed to db.save_local).\n",
    "DB_FAISS_PATH = 'vectorstore/db_faiss'\n",
    "\n",
    "\n",
    "class FileIngestor:\n",
    "    \"\"\"Index an uploaded PDF and render a Streamlit chat UI over its contents.\n",
    "\n",
    "    Args:\n",
    "        uploaded_file: Upload handle from the Streamlit file uploader; this\n",
    "            class reads its ``getvalue()`` bytes and its ``name``.\n",
    "    \"\"\"\n",
    "\n",
    "    # st.session_state keys under which the per-document chain is cached.\n",
    "    _CHAIN_KEY = 'fileingestor_chain'\n",
    "    _CHAIN_FILE_KEY = 'fileingestor_chain_file'\n",
    "\n",
    "    def __init__(self, uploaded_file):\n",
    "        self.uploaded_file = uploaded_file\n",
    "        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)\n",
    "\n",
    "    def _build_chain(self, pdf_bytes):\n",
    "        \"\"\"Load, split, embed and index the PDF, then wrap it in a retrieval chain.\n",
    "\n",
    "        Args:\n",
    "            pdf_bytes: Raw bytes of the uploaded PDF.\n",
    "\n",
    "        Returns:\n",
    "            The ConversationalRetrievalChain built from the LLM and a\n",
    "            retriever over the FAISS index of the PDF chunks.\n",
    "        \"\"\"\n",
    "        import os  # local import: only needed to clean up the temp file\n",
    "\n",
    "        # PyMuPDFLoader is given a file path, so spill the upload to a temp file.\n",
    "        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:\n",
    "            tmp_file.write(pdf_bytes)\n",
    "            tmp_file_path = tmp_file.name\n",
    "\n",
    "        try:\n",
    "            pages = PyMuPDFLoader(file_path=tmp_file_path).load()\n",
    "        finally:\n",
    "            # delete=False leaves the file on disk; remove it once it has been read.\n",
    "            try:\n",
    "                os.remove(tmp_file_path)\n",
    "            except OSError:\n",
    "                pass  # best-effort cleanup; a leftover temp file must not break the app\n",
    "\n",
    "        data = self.text_splitter.split_documents(pages)\n",
    "\n",
    "        # BGE embedding\n",
    "        embeddings = HuggingFaceBgeEmbeddings(\n",
    "            model_name='bge-large-zh-v1.5',\n",
    "            model_kwargs={'device': 'cuda'},\n",
    "            encode_kwargs={'normalize_embeddings': True},\n",
    "        )\n",
    "\n",
    "        # FAISS index, also saved to DB_FAISS_PATH\n",
    "        db = FAISS.from_documents(data, embeddings)\n",
    "        db.save_local(DB_FAISS_PATH)\n",
    "\n",
    "        # Load the language model and create the conversational chain\n",
    "        llm = Loadllm.load_llm()\n",
    "        return ConversationalRetrievalChain.from_llm(llm=llm, retriever=db.as_retriever())\n",
    "\n",
    "    def _get_chain(self):\n",
    "        \"\"\"Return the chain for the current upload, building it once per file.\n",
    "\n",
    "        Streamlit re-executes the script on every interaction, so the chain is\n",
    "        kept in ``st.session_state`` instead of being rebuilt (PDF parsing,\n",
    "        embedding, indexing, model loading) on each rerun. When a different\n",
    "        file is uploaded, the chain is rebuilt and the previous conversation\n",
    "        state is dropped.\n",
    "        \"\"\"\n",
    "        import hashlib  # local import: only needed to fingerprint the upload\n",
    "\n",
    "        pdf_bytes = self.uploaded_file.getvalue()\n",
    "        file_key = (self.uploaded_file.name, hashlib.sha256(pdf_bytes).hexdigest())\n",
    "\n",
    "        is_cached = (\n",
    "            self._CHAIN_KEY in st.session_state\n",
    "            and self._CHAIN_FILE_KEY in st.session_state\n",
    "            and st.session_state[self._CHAIN_FILE_KEY] == file_key\n",
    "        )\n",
    "        if not is_cached:\n",
    "            st.session_state[self._CHAIN_KEY] = self._build_chain(pdf_bytes)\n",
    "            st.session_state[self._CHAIN_FILE_KEY] = file_key\n",
    "            # Any stored greeting/history belongs to a previously uploaded document.\n",
    "            for stale_key in ('history', 'generated', 'past'):\n",
    "                if stale_key in st.session_state:\n",
    "                    del st.session_state[stale_key]\n",
    "\n",
    "        return st.session_state[self._CHAIN_KEY]\n",
    "\n",
    "    def handlefileandingest(self):\n",
    "        \"\"\"Render the chat interface for the uploaded PDF.\n",
    "\n",
    "        Obtains the retrieval chain for the upload, initialises the chat\n",
    "        state in ``st.session_state``, shows the query form, runs the chain\n",
    "        when a query is submitted, and displays the conversation so far.\n",
    "        \"\"\"\n",
    "        chain = self._get_chain()\n",
    "\n",
    "        # Initialize chat history\n",
    "        if 'history' not in st.session_state:\n",
    "            st.session_state['history'] = []\n",
    "\n",
    "        # Initialize messages\n",
    "        if 'generated' not in st.session_state:\n",
    "            st.session_state['generated'] = [\"Hello ! Ask me(yi-chat-6b) about \" + self.uploaded_file.name]\n",
    "\n",
    "        if 'past' not in st.session_state:\n",
    "            st.session_state['past'] = [\"Hey !\"]\n",
    "\n",
    "        def conversational_chat(query):\n",
    "            \"\"\"Ask the chain a question, record the exchange, and return the answer.\"\"\"\n",
    "            result = chain({\"question\": query, \"chat_history\": st.session_state['history']})\n",
    "            st.session_state['history'].append((query, result[\"answer\"]))\n",
    "            return result[\"answer\"]\n",
    "\n",
    "        # Create containers for chat history and user input\n",
    "        response_container = st.container()\n",
    "        container = st.container()\n",
    "\n",
    "        # User input form\n",
    "        with container:\n",
    "            with st.form(key='my_form', clear_on_submit=True):\n",
    "                user_input = st.text_input(\"Query:\", placeholder=\"Talk to PDF data\", key='input')\n",
    "                submit_button = st.form_submit_button(label='Send')\n",
    "\n",
    "            if submit_button and user_input:\n",
    "                output = conversational_chat(user_input)\n",
    "                st.session_state['past'].append(user_input)\n",
    "                st.session_state['generated'].append(output)\n",
    "\n",
    "        # Display chat history\n",
    "        if st.session_state['generated']:\n",
    "            with response_container:\n",
    "                exchanges = zip(st.session_state['past'], st.session_state['generated'])\n",
    "                for i, (user_text, bot_text) in enumerate(exchanges):\n",
    "                    message(user_text, is_user=True, key=str(i) + '_user', avatar_style=\"big-smile\")\n",
    "                    message(bot_text, key=str(i), avatar_style=\"thumbs\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
